import torch
from torch import nn
import math
import warnings
from torch.nn import init
import numpy as np
from uniperceiver.utils import comm

# Std of the truncated-normal init applied to nn.Linear weights and
# nn.MultiheadAttention projection weights in `init_timm_params`.
INIT_STD = 0.02
# Std used for nn.Embedding weights in `init_timm_params` (truncated normal)
# and `init_weights_mae` (plain normal).
INIT_EMBEDDING_STD = 0.02

def null_loss_check(outputs_dict):
    """Build a zero-valued loss term that depends on every shared target tensor.

    Each ``shared_target[0]['data']`` is multiplied by zero, summed, and added
    to the loss. For finite data the loss value is unchanged, but the result
    now depends on those tensors in the autograd graph.

    Args:
        outputs_dict: mapping with a ``'shared_target_sets'`` dict. Each value
            of that dict is a sequence whose first element is a dict holding a
            ``'data'`` tensor. An optional ``'null_loss'`` entry is used as the
            starting value (defaults to 0).

    Returns:
        ``{'null_loss': loss}``. The caller's ``outputs_dict['null_loss']`` is
        not modified.
    """
    null_loss = outputs_dict.get('null_loss', 0)
    for shared_target in outputs_dict['shared_target_sets'].values():
        # Out-of-place add: ``+=`` on a tensor would modify the caller's
        # ``outputs_dict['null_loss']`` in place. Autograd rejects that for a
        # leaf tensor that requires grad.
        null_loss = null_loss + torch.sum(shared_target[0]['data'] * 0)
    return {'null_loss': null_loss}

def build_2d_sincos_position_embedding(cfg, video_embed, cls_token=False, temperature=10000., pos_emd_fix=False):
    """Replace ``video_embed``'s spatial position table with a 2D sin-cos embedding.

    The grid is ``side x side`` with ``side = int(sqrt(video_embed.max_spatial_size))``.
    With ``P = HIDDEN_SIZE // 4`` frequencies ``omega_k = 1 / temperature**(k / P)``,
    each row of the table is
    ``[sin(gw*omega), cos(gw*omega), sin(gh*omega), cos(gh*omega)]``,
    where ``gw``/``gh`` are that cell's (optionally scaled) grid coordinates.

    Args:
        cfg: config object. Reads ``MODEL.POSEMBED_SCALE`` (multiplies the grid
            coordinates when != 1.0), ``MODEL.BERT.HIDDEN_SIZE`` (table width)
            and ``MODEL.POSEMBEDFIX``.
        video_embed: object exposing ``max_spatial_size`` and
            ``embeddings_st_pos.spatial_pos_embed``. The latter's ``weight`` is
            replaced by a new ``nn.Parameter``.
        cls_token: if True, prepend an all-zero row to the table.
        temperature: base of the frequency schedule.
        pos_emd_fix: if True, freeze the new table (``requires_grad=False``).
            This argument used to be accepted but ignored. The table is also
            frozen whenever ``cfg.MODEL.POSEMBEDFIX`` is set.

    Raises:
        AssertionError: if ``HIDDEN_SIZE`` is not divisible by 4.
    """
    hidden_size = cfg.MODEL.BERT.HIDDEN_SIZE
    side = int(video_embed.max_spatial_size ** .5)
    h, w = side, side

    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    # torch.meshgrid's default ('ij') indexing; kept so the table layout is unchanged.
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
    if cfg.MODEL.POSEMBED_SCALE != 1.0:
        grid_w = grid_w * cfg.MODEL.POSEMBED_SCALE
        grid_h = grid_h * cfg.MODEL.POSEMBED_SCALE

    assert hidden_size % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
    pos_dim = hidden_size // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature**omega)
    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
    pos_emb = torch.cat([
        torch.sin(out_w),
        torch.cos(out_w),
        torch.sin(out_h),
        torch.cos(out_h)
    ], dim=1)

    if cls_token:
        # Zero row for the single cls token placed in front of the grid.
        pe_token = torch.zeros([1, hidden_size], dtype=torch.float32)
        pos_emb = torch.cat([pe_token, pos_emb], dim=0)

    spatial_pos_embed = video_embed.embeddings_st_pos.spatial_pos_embed
    spatial_pos_embed.weight = nn.Parameter(pos_emb)
    if pos_emd_fix or cfg.MODEL.POSEMBEDFIX:
        spatial_pos_embed.weight.requires_grad = False


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with N(mean, std^2) truncated to [a, b]; return it.

    ``a`` and ``b`` are absolute bounds, not multiples of ``std``.
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled

def truncated_normal_(tensor, mode='fan_in', gain=0.1):
    """Fan-scaled truncated-normal init (used by ``init_switchtransformer_params``).

    Fills ``tensor`` in place with N(0, std^2), where ``std = sqrt(gain / fan)``,
    truncated at two standard deviations.

    Args:
        tensor: weight tensor to fill in place (e.g. an ``nn.Linear`` weight).
        mode: ``'fan_in'`` or ``'fan_out'``; selects which fan scales the std.
        gain: numerator of the variance. Defaults to 0.1, the value that was
            previously hard-coded.
    """
    fan = init._calculate_correct_fan(tensor, mode=mode)
    std = math.sqrt(gain / fan)
    # ``init.trunc_normal_``'s ``a``/``b`` are absolute cut-offs, not
    # multiples of std. With the default (-2, 2) and a typical std of ~0.01,
    # nothing would ever be truncated, so pass +/-2 std explicitly.
    init.trunc_normal_(tensor, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)

def normal_(data):
    """Fill ``data`` in place with N(0, 0.02^2) samples drawn on the CPU.

    Samples are generated on a CPU copy and then copied back to
    ``data.device``. The original note gives the reason: under FSDP the
    parameters live on CUDA, and sampling on the CPU keeps the RNG stream the
    same with and without FSDP.
    """
    home_device = data.device
    cpu_samples = data.cpu().normal_(mean=0.0, std=0.02)
    data.copy_(cpu_samples.to(home_device))

def init_bert_params(module):
    """BERT-style init for a single module; intended for ``model.apply``.

    - ``nn.Linear``: weight ~ N(0, 0.02^2) via ``normal_``; bias zeroed.
    - ``nn.Embedding``: weight ~ N(0, 0.02^2); the padding row (if any) is zeroed.
    - ``nn.MultiheadAttention``: input projection weight(s) ~ N(0, 0.02^2).
      Its ``out_proj`` is a separate submodule (an ``nn.Linear`` subclass in
      PyTorch), so it is handled when ``apply`` visits it.
    """
    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, nn.MultiheadAttention):
        if module.in_proj_weight is not None:
            normal_(module.in_proj_weight.data)
        else:
            # kdim/vdim differ from embed_dim: PyTorch then stores separate
            # q/k/v projection weights and leaves in_proj_weight as None.
            normal_(module.q_proj_weight.data)
            normal_(module.k_proj_weight.data)
            normal_(module.v_proj_weight.data)

def init_switchtransformer_params(module):
    """Switch-Transformer-style init for one module; intended for ``model.apply``.

    - ``nn.Linear``: weight via ``truncated_normal_`` (fan-scaled); bias zeroed.
    - ``nn.Embedding``: weight via ``normal_``; padding row (if any) zeroed.

    Other module types are left untouched.
    """
    is_linear = isinstance(module, nn.Linear)
    is_embedding = isinstance(module, nn.Embedding)

    if is_linear:
        truncated_normal_(module.weight)
        bias = module.bias
        if bias is not None:
            bias.data.zero_()

    if is_embedding:
        table = module.weight.data
        normal_(table)
        pad_row = module.padding_idx
        if pad_row is not None:
            table[pad_row].zero_()


def init_timm_params(m):
    """timm-style init for one module; intended for ``model.apply``.

    - ``nn.Linear``: weight ~ truncated normal (std=INIT_STD, absolute bounds
      [-2, 2]); bias set to 0.
    - ``nn.Embedding``: weight ~ truncated normal (std=INIT_EMBEDDING_STD);
      padding row (if any) zeroed.
    - ``nn.MultiheadAttention``: q/k/v input projection weights ~ truncated
      normal (std=INIT_STD). ``out_proj`` is a separate Linear submodule and
      is handled when ``apply`` visits it.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=INIT_STD)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Embedding):
        trunc_normal_(m.weight.data, std=INIT_EMBEDDING_STD)
        if m.padding_idx is not None:
            m.weight.data[m.padding_idx].zero_()
    if isinstance(m, nn.MultiheadAttention):
        # torch.nn.MultiheadAttention has no q_proj/k_proj/v_proj submodules.
        # It holds either one packed ``in_proj_weight`` or, when kdim/vdim
        # differ from embed_dim, separate ``*_proj_weight`` parameters.
        # Filling the packed tensor with the same std is equivalent to
        # filling the Q, K and V slices separately.
        if m.in_proj_weight is not None:
            trunc_normal_(m.in_proj_weight.data, std=INIT_STD)
        else:
            trunc_normal_(m.q_proj_weight.data, std=INIT_STD)
            trunc_normal_(m.k_proj_weight.data, std=INIT_STD)
            trunc_normal_(m.v_proj_weight.data, std=INIT_STD)

def initialize_weights_as_mae(model):
    """Initialize ``model`` following MAE.

    1. Apply ``init_weights_mae`` to every submodule.
    2. If ``model.video_embed`` is not None:
       - overwrite its spatial position table with a 2D sin-cos embedding
         (frozen when ``cfg.MODEL.POSEMBEDFIX`` is set);
       - initialize its patch projection like an ``nn.Linear``: Xavier-uniform
         over the weight flattened to (out, -1), and a zero bias.
    """
    # initialize nn.Linear / nn.Embedding / nn.LayerNorm / nn.MultiheadAttention
    model.apply(init_weights_mae)

    video_embed = model.video_embed
    if video_embed is None:
        # No video branch. Previously only the position embedding was
        # guarded, and the patch-embed init below dereferenced None.
        return

    build_2d_sincos_position_embedding(model.cfg, video_embed)

    # initialize patch_embed like nn.Linear (instead of e.g. nn.Conv2d)
    w = video_embed.embeddings.weight.data
    torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    if video_embed.embeddings.bias is not None:
        nn.init.zeros_(video_embed.embeddings.bias)


def initialize_weights_as_mocov3(model):
    """Initialize ``model`` MoCo-v3 style: MAE init, then a tiny-std cls token.

    Calls ``model.initialize_weights_as_mae()``. Then redraws the last row of
    ``model.token_embed.embeddings.weight`` from N(0, (1e-6)^2).
    """
    # NOTE(review): this calls a *method* on ``model``, not the module-level
    # ``initialize_weights_as_mae(model)`` function; confirm that the model
    # class defines it.
    model.initialize_weights_as_mae()

    # The last token-embedding row serves as the cls token and gets a much
    # smaller std than the other rows.
    cls_row = model.token_embed.embeddings.weight[-1, :]
    nn.init.normal_(cls_row, std=1e-6)


def init_weights_mae(m):
    """MAE / JAX-ViT style init for one module; intended for ``model.apply``.

    - ``nn.Linear``: Xavier-uniform weight. If out_features == 3 * in_features
      (a fused QKV projection), the Xavier bound is computed for a single
      (in_features, out_features // 3) projection, so Q, K and V get the same
      scale as separate layers. Bias set to 0.
    - ``nn.Embedding``: N(0, INIT_EMBEDDING_STD^2); padding row (if any) zeroed.
    - ``nn.LayerNorm``: weight 1, bias 0. Each is skipped when absent
      (e.g. ``elementwise_affine=False``).
    - ``nn.MultiheadAttention``: Xavier-uniform separate q/k/v weights, or the
      per-projection Xavier bound applied to a packed ``in_proj_weight``.
    """
    # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
    if isinstance(m, nn.Linear):
        # we use xavier_uniform following official JAX ViT
        if m.weight.shape[0] == m.weight.shape[1] * 3:
            # treat the weights of Q, K, V separately
            val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
            nn.init.uniform_(m.weight, -val, val)
        else:
            torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    # all word embeddings, e.g. word / spe / type / position embeddings.
    # MAE itself only has embeddings like cls_token and mask tokens.
    elif isinstance(m, nn.Embedding):
        torch.nn.init.normal_(m.weight.data, std=INIT_EMBEDDING_STD)
        if m.padding_idx is not None:
            m.weight.data[m.padding_idx].zero_()

    elif isinstance(m, nn.LayerNorm):
        # A non-affine LayerNorm registers weight/bias as None, and
        # constant_(None, ...) would raise.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)

    elif isinstance(m, nn.MultiheadAttention):
        if m.q_proj_weight is not None:
            torch.nn.init.xavier_uniform_(m.q_proj_weight.data)
            torch.nn.init.xavier_uniform_(m.k_proj_weight.data)
            torch.nn.init.xavier_uniform_(m.v_proj_weight.data)
        else:
            # treat the weights of Q, K, V separately
            val = math.sqrt(6. / float(m.in_proj_weight.shape[0] // 3 + m.in_proj_weight.shape[1]))
            nn.init.uniform_(m.in_proj_weight, -val, val)

def data_half(fp16, bf16, data):
    """Cast every float32 tensor value in ``data`` to half precision, in place.

    ``fp16`` takes priority: when it is truthy, float32 tensors become
    float16. Otherwise, when ``bf16`` is truthy, they become bfloat16. With
    both flags falsy nothing changes. Non-tensor values and tensors of other
    dtypes are left alone.

    Returns:
        The same ``data`` dict object.
    """
    if fp16:
        target_dtype = torch.float16
    elif bf16:
        target_dtype = torch.bfloat16
    else:
        return data

    float32_keys = [
        key for key, value in data.items()
        if isinstance(value, torch.Tensor) and value.dtype == torch.float32
    ]
    for key in float32_keys:
        data[key] = data[key].to(target_dtype)
    return data

def postprocess(data_dict:dict, task_info:dict ):
    """Gather the first-token feature from all ranks when running distributed.

    Acts only when ``data_dict['sample_info']`` is present and its
    ``'distributed'`` flag is truthy. In that case it:

    1. takes ``data[:, 0]`` (the spe token feature);
    2. all-gathers it with ``torch.distributed.nn.all_gather`` and
       concatenates the results along dim 0;
    3. keeps at most ``sample_info['total_num']`` rows (presumably dropping
       padding added for even sharding -- TODO confirm);
    4. stores the result back as ``data_dict['data']`` with a new dim 1.

    Mutates ``data_dict`` in place and returns None. ``task_info`` is unused.
    """
    sample_info = data_dict.get('sample_info', None)
    if sample_info is None or not sample_info.get('distributed', False):
        return

    # ``import torch`` does not load the ``torch.distributed.nn`` submodule
    # by itself; import it explicitly instead of relying on another module
    # having done so.
    import torch.distributed.nn as dist_nn

    data = data_dict['data']
    hidden_states = data[:, 0].contiguous()  # HERE only use the spe token feature!
    hidden_states = torch.cat(dist_nn.all_gather(hidden_states))

    total_length = sample_info['total_num']
    if hidden_states.shape[0] > total_length:
        hidden_states = hidden_states[:total_length]

    data_dict['data'] = hidden_states.unsqueeze(1)


def get_spe_token(tokenizer, token_embed):
    """Embed the special ``spe`` token as a batch of one, on CUDA.

    When ``comm.old_checkpoint`` is set, the legacy ``'<|spe|>'`` string is
    encoded and embedded with ``type_embed=False, pos_embed=False``.
    Otherwise the string ``'spe'`` is encoded and ``token_embed`` is called
    with its default options.
    """
    legacy = comm.old_checkpoint
    token_text = '<|spe|>' if legacy else 'spe'
    token_ids = torch.tensor(tokenizer.encode(token_text)).cuda().unsqueeze(0)  # bs, 1
    if legacy:
        return token_embed(token_ids, type_embed=False, pos_embed=False)
    return token_embed(token_ids)

def preprocess(tokenizer, token_embed, data_list:list, task_info:dict):
    """Concatenate several input segments into one sequence for the fused encoder.

    Args:
        tokenizer, token_embed: used to build the spe token embedding
            (via ``get_spe_token``).
        data_list: list of dicts. Each has a ``'data'`` tensor whose dim 0 is
            the batch and dim 1 the segment length; segments are concatenated
            along dim 1. Optional keys: ``'invalid_mask'`` (written into the
            segment's column range of the combined mask), ``'sample_info'``
            (dict or list), ``'data_type'`` and ``'moe_embedding'``.
        task_info: reads ``'prefix_spe_before_fuse'`` (default True). When
            True, the spe token is placed in front of all segments.

    Returns:
        A dict with:
          - ``'data'``: the concatenated tensor;
          - ``'invalid_mask'``: uint8 mask, or None when no segment has one;
          - ``'data_type'``, ``'moe_embedding'``: taken from the last segment
            that defines them, else None;
          - ``'sample_info'``: merged info containing ``'data_length'``,
            ``'data_cum_length'`` and ``'sample_info_per_sample'``.
    """
    # preparation for fused_encoder input
    bs = data_list[0]['data'].shape[0]
    device = data_list[0]['data'].device
    mask_dtype = torch.uint8

    #TODO: prompt embedding

    prefix_spe_before_fuse = task_info.get('prefix_spe_before_fuse', True)

    combined_data = []
    # spe embedding (computed even when it is not prefixed)
    spe_token = get_spe_token(tokenizer, token_embed).expand(bs, -1, -1)

    length = [data_dict['data'].shape[1] for data_dict in data_list]
    if prefix_spe_before_fuse:
        length = [1] + length
        combined_data.append(spe_token)

    cum_length = np.cumsum(length).tolist()

    invalid_mask_active = any(data_dict.get('invalid_mask', None) is not None for data_dict in data_list)
    if invalid_mask_active:
        combined_valid_mask = torch.zeros((bs, cum_length[-1]), dtype=mask_dtype, device=device)
    else:
        combined_valid_mask = None

    # ``starts[j]`` is the first column of entry j of ``length``. data_list[i]
    # is entry ``i + 1`` when the spe token is prefixed, else entry ``i``.
    # The previous slice ``cum_length[i]:cum_length[i+1]`` was only correct
    # in the prefixed case. Without the prefix it shifted every mask by one
    # segment and indexed past the end for the last segment.
    seg_offset = 1 if prefix_spe_before_fuse else 0
    starts = [0] + cum_length
    for i, data_dict in enumerate(data_list):
        combined_data.append(data_dict['data'])
        invalid_mask = data_dict.get('invalid_mask', None)
        if invalid_mask is not None:
            j = i + seg_offset
            combined_valid_mask[:, starts[j]:starts[j + 1]] = invalid_mask

    combined_data = torch.cat(combined_data, dim=1)

    sample_info = {
        'data_length': length,
        'data_cum_length': cum_length,
        'sample_info_per_sample': []}

    # for caption task inference
    if comm._CAPTION_GEN_MODE:
        sample_info['data_cum_length'] = data_list[0]['sample_info']['data_cum_length']

    for data_dict in data_list:
        info = data_dict.get('sample_info', None)
        if info is None:
            continue
        if isinstance(info, dict):
            sample_info.update(info)
        elif isinstance(info, list):
            # Guard the [0] lookup so an empty list does not raise IndexError.
            if len(info) > 0 and isinstance(info[0], dict):
                sample_info.update(info[0])
            sample_info['sample_info_per_sample'].append(info)

    # Default to None so a data_list without any 'data_type' entry does not
    # raise UnboundLocalError at the return below.
    data_type = None
    moe_embedding = None
    for data_dict in data_list:
        if 'data_type' in data_dict:
            data_type = data_dict['data_type']
        if 'moe_embedding' in data_dict:
            moe_embedding = data_dict['moe_embedding']

    return {
        'data': combined_data,
        'invalid_mask': combined_valid_mask,
        'data_type': data_type,
        'sample_info': sample_info,
        'moe_embedding': moe_embedding,
    }

def share_token_embed_ln(video_embed, token_embed):
    """Make ``video_embed`` reuse ``token_embed``'s ``embeddings_norm`` module.

    The video embedding's own ``embeddings_norm`` is deleted and replaced by a
    reference to the token embedding's, so both point at the same module
    object. Does nothing if either argument is None.
    """
    if video_embed is None or token_embed is None:
        return
    del video_embed.embeddings_norm
    video_embed.embeddings_norm = token_embed.embeddings_norm
